/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */

import * as tfc from '@tensorflow/tfjs-core';

import {Shape} from '../keras_format/common';
import {TensorOrArrayOrMap} from '../types';

import {Dataset, LazyIterator} from './dataset_stub';
import {FitDatasetElement} from './training_dataset';

/**
 * Configuration for `FakeNumericDataset`.
 */
export interface FakeDatasetArgs {
  /**
   * The shape(s) of the features of a single example, without the batch
   * dimension (the batch size is prepended when batches are built).
   *
   * Use an object mapping name to shape, if more than one feature tensor
   * is required.
   */
  xShape: Shape|{[name: string]: Shape};

  /**
   * The shape(s) of the target(s) of a single example, without the batch
   * dimension.
   *
   * Use an object mapping name to shape, if more than one target tensor
   * is required.
   */
  yShape: Shape|{[name: string]: Shape};

  /**
   * A function that generates a preset sequence of X tensors.
   *
   * This function is invoked each time a new iterator is created. Element `i`
   * of the returned array (or of each named array) is used as the features of
   * batch `i`, and its shape is checked against the batch shape derived from
   * `xShape` and `batchSize`.
   *
   * Must be set if and only if `yTensorsFunc` is set. When both are omitted,
   * random-normal tensors are generated for every batch instead.
   */
  xTensorsFunc?: () => tfc.Tensor[] | {[name: string]: tfc.Tensor[]};

  /**
   * A function that generates a preset sequence of Y tensors.
   *
   * This function is invoked each time a new iterator is created. Element `i`
   * of the returned array (or of each named array) is used as the targets of
   * batch `i`, and its shape is checked against the batch shape derived from
   * `yShape` and `batchSize`.
   *
   * Must be set if and only if `xTensorsFunc` is set.
   */
  yTensorsFunc?: () => tfc.Tensor[] | {[name: string]: tfc.Tensor[]};

  /**
   * The size of each batch generated by the iterator. Must be a positive
   * integer.
   */
  batchSize: number;

  /**
   * The number of batches an iterator generates before declaring done to be
   * true. Must be a positive integer.
   */
  numBatches: number;
}

/**
 * Prepends the batch dimension to a per-example shape.
 *
 * For the name-keyed form, the batch dimension is prepended to every entry and
 * a new object with the same keys is returned; the input is left untouched.
 */
function mergeBatchSizeAndShape(
    batchSize: number, shape: Shape|{[name: string]: Shape}): Shape|
    {[name: string]: Shape} {
  if (!Array.isArray(shape)) {
    const batchShapes: {[name: string]: Shape} = {};
    for (const key in shape) {
      batchShapes[key] = [batchSize, ...shape[key]];
    }
    return batchShapes;
  }
  return [batchSize, ...shape];
}

/**
 * Samples random-normal tensors matching `shape`.
 *
 * A single shape yields a single tensor; a name-to-shape object yields an
 * object holding one freshly sampled tensor per name.
 */
function generateRandomTensorContainer(shape: Shape|{[name: string]: Shape}):
    tfc.Tensor|{[name: string]: tfc.Tensor} {
  if (Array.isArray(shape)) {
    return tfc.randomNormal(shape);
  }
  const tensorsByName: {[name: string]: tfc.Tensor} = {};
  for (const name in shape) {
    tensorsByName[name] = tfc.randomNormal(shape[name]);
  }
  return tensorsByName;
}

/**
 * Iterator backing `FakeNumericDataset`.
 *
 * Yields `numBatches` `{xs, ys}` elements and then reports `done`. Without
 * preset tensor functions, each batch is sampled with `tfc.randomNormal`.
 * With them, batch `i` is element `i` of the arrays returned by
 * `xTensorsFunc`/`yTensorsFunc` (or element `i` under every name, for the
 * name-keyed form). Each preset tensor's shape is checked against the
 * corresponding batch shape.
 */
class FakeNumericIterator extends LazyIterator<FitDatasetElement> {
  // Per-example shapes with the batch dimension prepended.
  private readonly xBatchShape: Shape|{[name: string]: Shape};
  private readonly yBatchShape: Shape|{[name: string]: Shape};
  private readonly numBatches: number;
  // Number of calls to next() so far, including calls made after `done`.
  private batchCount: number;
  private readonly xTensorsFunc:
      () => tfc.Tensor[] | {[name: string]: tfc.Tensor[]};
  private readonly yTensorsFunc:
      () => tfc.Tensor[] | {[name: string]: tfc.Tensor[]};
  // Preset tensors, materialized when the first batch is requested.
  private xTensorValues: tfc.Tensor[]|{[name: string]: tfc.Tensor[]};
  private yTensorValues: tfc.Tensor[]|{[name: string]: tfc.Tensor[]};
  private tensorIndex = 0;

  constructor(args: FakeDatasetArgs) {
    super();
    this.xBatchShape = mergeBatchSizeAndShape(args.batchSize, args.xShape);
    this.yBatchShape = mergeBatchSizeAndShape(args.batchSize, args.yShape);
    this.numBatches = args.numBatches;
    this.batchCount = 0;
    this.xTensorsFunc = args.xTensorsFunc;
    this.yTensorsFunc = args.yTensorsFunc;

    // Preset tensors must be provided for features and targets together.
    tfc.util.assert(
        (this.xTensorsFunc == null) === (this.yTensorsFunc == null),
        () => 'xTensorsFunc and yTensorsFunc must be both null/undefined ' +
            'or both set.');
  }

  async next(): Promise<IteratorResult<FitDatasetElement>> {
    const done = ++this.batchCount > this.numBatches;
    if (done) {
      return {done, value: null};
    }
    if (this.xTensorsFunc == null) {
      // No preset tensors: sample a fresh random batch.
      return {
        done,
        value: {
          xs: generateRandomTensorContainer(this.xBatchShape),
          ys: generateRandomTensorContainer(this.yBatchShape)
        }
      };
    }
    // Preset tensors are produced lazily on the first batch of this iterator.
    // Past this point batchCount is in [1, numBatches], so this runs exactly
    // once per iterator.
    if (this.batchCount === 1) {
      this.xTensorValues = this.xTensorsFunc();
      this.yTensorValues = this.yTensorsFunc();
      this.tensorIndex = 0;
    }
    const index = this.tensorIndex++;
    const xs = FakeNumericIterator.pickPresetTensors(
        this.xTensorValues, this.xBatchShape, index);
    const ys = FakeNumericIterator.pickPresetTensors(
        this.yTensorValues, this.yBatchShape, index);
    return {done, value: {xs, ys}};
  }

  /**
   * Returns the preset tensor(s) for batch `index`, asserting that each one
   * exists and matches its expected batch shape.
   *
   * @param values Preset tensors: a single array, or one array per name.
   * @param batchShape Expected batch shape, in the same form as `values`.
   *   In the name-keyed form each name is checked against its own entry.
   * @param index Position of the batch within the preset arrays.
   */
  private static pickPresetTensors(
      values: tfc.Tensor[]|{[name: string]: tfc.Tensor[]},
      batchShape: Shape|{[name: string]: Shape},
      index: number): tfc.Tensor|{[name: string]: tfc.Tensor} {
    if (Array.isArray(values)) {
      const tensor = values[index];
      FakeNumericIterator.checkPresetTensor(
          tensor, batchShape as Shape, index, '');
      return tensor;
    }
    const shapes = batchShape as {[name: string]: Shape};
    const output: {[name: string]: tfc.Tensor} = {};
    for (const name in values) {
      const tensor = values[name][index];
      FakeNumericIterator.checkPresetTensor(
          tensor, shapes[name], index, ` for '${name}'`);
      output[name] = tensor;
    }
    return output;
  }

  /**
   * Asserts that a preset tensor is present and has `expectedShape`.
   *
   * @param label Suffix identifying the tensor in error messages (empty for
   *   the single-tensor form).
   */
  private static checkPresetTensor(
      tensor: tfc.Tensor, expectedShape: Shape, index: number,
      label: string): void {
    tfc.util.assert(
        tensor != null,
        () => `Missing preset tensor${label} at batch index ${index}; the ` +
            'tensors function must return at least numBatches tensors.');
    tfc.util.assert(
        tfc.util.arraysEqual(tensor.shape, expectedShape),
        () => `Shape mismatch${label}: expected: ${
                  JSON.stringify(expectedShape)}; ` +
            `actual: ${JSON.stringify(tensor.shape)}`);
  }
}

/**
 * A fake dataset with configurable feature and target shapes.
 *
 * The batch size and the number of batches are configurable as well. Unless
 * `xTensorsFunc` and `yTensorsFunc` are provided, every batch produced by the
 * dataset's iterators holds freshly sampled random-normal values. When they
 * are provided, the preset tensors they return are yielded in order instead.
 *
 * Construction fails if `batchSize` or `numBatches` is not a positive
 * integer.
 */
export class FakeNumericDataset extends Dataset<FitDatasetElement> {
  constructor(readonly args: FakeDatasetArgs) {
    super();
    const {batchSize, numBatches} = args;
    tfc.util.assert(
        Number.isInteger(batchSize) && batchSize > 0,
        () => `batchSize must be a positive integer, but got ${batchSize}`);
    tfc.util.assert(
        Number.isInteger(numBatches) && numBatches > 0,
        () => `numBatches must be positive integer, but got ${numBatches}`);
    // The dataset has exactly one element per batch.
    this.size = numBatches;
  }

  /** Creates a new iterator, which always starts from the first batch. */
  async iterator(): Promise<LazyIterator<FitDatasetElement>> {
    const freshIterator = new FakeNumericIterator(this.args);
    return freshIterator;
  }
}

// Exposes a FakeNumericDataset in the legacy `[xs, ys]` tuple form.
// Dataset.map(...) is not an option because this package does not depend on
// tfjs-data, so the {xs, ys} -> [xs, ys] conversion is done by hand in
// FakeNumericIteratorLegacyArrayForm.
export class FakeNumericDatasetLegacyArrayForm extends
    Dataset<[TensorOrArrayOrMap, TensorOrArrayOrMap]> {
  // The underlying {xs, ys} dataset that this wrapper converts.
  ds: FakeNumericDataset;

  constructor(readonly args: FakeDatasetArgs) {
    super();
    const wrapped = new FakeNumericDataset(args);
    this.ds = wrapped;
  }

  async iterator():
      Promise<LazyIterator<[TensorOrArrayOrMap, TensorOrArrayOrMap]>> {
    return new FakeNumericIteratorLegacyArrayForm(await this.ds.iterator());
  }
}

/**
 * Wraps an `{xs, ys}` iterator and re-emits each element as an `[xs, ys]`
 * pair. The `done` flag is passed through unchanged, and a null/undefined
 * element is reported as `null`.
 */
class FakeNumericIteratorLegacyArrayForm extends
    LazyIterator<[TensorOrArrayOrMap, TensorOrArrayOrMap]> {
  constructor(private readonly it: LazyIterator<FitDatasetElement>) {
    super();
  }

  async next():
      Promise<IteratorResult<[TensorOrArrayOrMap, TensorOrArrayOrMap]>> {
    const {done, value} = await this.it.next();
    if (value == null) {
      return {done, value: null};
    }
    return {done, value: [value.xs, value.ys]};
  }
}
